Skip to content

Pass input_output_alias to TritonAutotunedKernelCall#2814

Open
tdophung wants to merge 3 commits intoNVIDIA:mainfrom
tdophung:sort_chunks_WAR_p2
Open

Pass input_output_alias to TritonAutotunedKernelCall#2814
tdophung wants to merge 3 commits intoNVIDIA:mainfrom
tdophung:sort_chunks_WAR_p2

Conversation

@tdophung
Copy link
Copy Markdown
Collaborator

@tdophung tdophung commented Mar 31, 2026

Description

https://nvbugspro.nvidia.com/bug/5810384
To remove the WAR that was put in place for this bug.

This should also serves as part 2 to WAR to the intermittent sort_chunks_by_index bug seen before in #2730

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

if JAX version >= 0.9.3, which contains the fix to restore all aliased input buffers that was saved away during autotuning: we pass the input_output_alias tuples to TritonAutotunedKernelCall

If JAX version < 0.9.3, which does not contain the fix, we pass an empty dict to the call.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: JAX Toolbox <jax@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This PR removes a workaround (WAR) that previously suppressed input_output_aliases from being passed to TritonAutotunedKernelCall. The WAR was needed because jaxlib/gpu/triton_kernels.cc had a bug where the autotuning restore phase iterated over all declared aliases unconditionally, while input_copies only contained entries for aliases where XLA actually shared buffers at runtime — accessing a missing entry produced a null vector whose .data() returned nullptr, causing CUDA_ERROR_INVALID_VALUE.

The fix in jax-ml/jax#35218 (merged 2026-03-17) corrects this, and ships in JAX 0.9.3. The PR gates the new behavior behind a version check (TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION = \"0.9.3\"): on older JAX the empty-tuple WAR is preserved (now with an explanatory UserWarning), while on JAX >= 0.9.3 the aliases are correctly assembled as (input_idx, num_inputs + output_idx, size_bytes) tuples and forwarded.

Key changes:

  • version_utils.py: adds TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION = \"0.9.3\" with a detailed comment and exports it.
  • utils.py: replaces the hardcoded () for input_output_aliases_with_sizes with a version-gated build path; emits a UserWarning on older JAX when aliases are requested but cannot be safely passed.

Confidence Score: 5/5

Safe to merge — the WAR is preserved for all currently-released JAX versions, and the new path is correctly gated behind the upstream fix version.

No P0 or P1 issues found. The alias tuple construction (input_idx, num_inputs + output_idx, size_bytes) is correct. The version gate correctly keeps the old WAR active on JAX < 0.9.3, so there is no regression risk on currently available JAX releases. The UserWarning is informative and correctly stacked. The upstream fix reference (jax-ml/jax#35218) is well-documented.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/jax/version_utils.py Adds TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION = "0.9.3" constant with a clear comment linking to the upstream JAX fix, and exports it in all. No issues found.
transformer_engine/jax/triton_extensions/utils.py Replaces the unconditional empty-tuple WAR with a version-gated path: on JAX >= 0.9.3 the alias list is built correctly (input_idx, num_inputs + output_idx, size_bytes); on older JAX a UserWarning is emitted and the WAR is preserved. Logic looks correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["triton_call_lowering() called\nwith input_output_aliases"] --> B{is_autotuned?}
    B -- No --> C["TritonKernelCall\n(aliases via ffi_lowering only)"]
    B -- Yes --> D{input_output_aliases\nis truthy?}
    D -- No --> E["input_output_aliases_with_sizes = ()"]
    D -- Yes --> F{JAX >=\n0.9.3?}
    F -- Yes --> G["Build alias tuples:\n(input_idx,\n num_inputs + output_idx,\n size_bytes)"]
    G --> H["TritonAutotunedKernelCall\nwith aliases ✓"]
    F -- No --> I["UserWarning emitted\ninput_output_aliases_with_sizes = () WAR"]
    I --> J["TritonAutotunedKernelCall\nwith empty aliases (safe WAR)"]
    E --> J
Loading

Reviews (3): Last reviewed commit: "Add jax version guard for the input_outp..." | Re-trigger Greptile

Signed-off-by: tdophung <tdophung@nvidia.com>
Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI, thanks! Good idea to consolidate all the Triton+JAX version requirements in a single place!

@tdophung
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants